Improve priority-based ordering for recompute #20117

Closed
wants to merge 21 commits into from

Conversation

pengwa
Contributor

@pengwa pengwa commented Mar 28, 2024

Improve priority-based ordering for recompute

Critical Path Impact in Priority-Based Topological Sort

Setting and comparing critical path impact is essential in scenarios where all nodes in the priority queue
are of low priority. This comparison helps determine which recompute node to select to unblock the backward
critical path. Consider a scenario where:
- A recompute node subgraph, NodeSubgraph-N, exists within transformer layer N (e.g., NodeSubgraph-5 in
layer 5, NodeSubgraph-3 in layer 3).
- Node-A-IN-5 within NodeSubgraph-5 depends on Node-B-IN-0 within NodeSubgraph-0.
- The priority queue contains nodes from NodeSubgraph-0 to NodeSubgraph-5.
In MemoryOptimizer recompute scenarios, we append nodes starting from NodeSubgraph-5 down to NodeSubgraph-0.
Relying solely on node index for comparison could lead to:

  1. Sequential output of nodes in NodeSubgraph-5 (sorted by ascending node index within the subgraph).
  2. Blocking of Node-A-IN-5's execution until Node-B-IN-0 is executed.
  3. Sequential output of nodes from NodeSubgraph-4 to NodeSubgraph-1.
  4. Execution of NodeSubgraph-0 nodes, allowing Node-A-IN-5 and subsequent NodeSubgraph-5 nodes to execute.
  5. Execution of remaining NodeSubgraph-0 nodes.

This ordering can significantly delay the execution of Node-A-IN-5, which in turn blocks the other
NodeSubgraph-5 nodes that depend on it. Since NodeSubgraph-5 nodes are on the critical path, their
dependencies must be triggered promptly so that they execute as early as possible, ahead of other layers.
This need led to the introduction of critical path impact.
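
To make the scenario above concrete, here is a minimal Python sketch (not the ONNX Runtime implementation) of a priority-based topological sort over a toy graph with three recompute subgraphs instead of six. The node names (`s2_A`, `s0_B`, ...), the impact values, and the helper `priority_topo_sort` are all hypothetical and chosen only for illustration; `s2_A` plays the role of Node-A-IN-5 and `s0_B` the role of Node-B-IN-0.

```python
import heapq

def priority_topo_sort(deps, key):
    """Kahn's algorithm; among ready nodes, always pop the smallest key."""
    remaining = {n: len(d) for n, d in deps.items()}
    consumers = {n: [] for n in deps}
    for node, inputs in deps.items():
        for inp in inputs:
            consumers[inp].append(node)
    ready = [(key(n), n) for n, count in remaining.items() if count == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        _, node = heapq.heappop(ready)
        order.append(node)
        for consumer in consumers[node]:
            remaining[consumer] -= 1
            if remaining[consumer] == 0:
                heapq.heappush(ready, (key(consumer), consumer))
    return order

# Toy graph: nodes are appended from the top layer (subgraph 2) down to
# subgraph 0, so top-layer nodes get the lowest node indices.
deps = {
    "s2_x": [],
    "s2_A": ["s2_x", "s0_B"],  # depends on a node in subgraph 0
    "s2_y": ["s2_A"],
    "s1_x": [],
    "s1_y": ["s1_x"],
    "s0_x": [],
    "s0_B": ["s0_x"],
    "s0_y": ["s0_B"],
}
index = {name: i for i, name in enumerate(deps)}

# Hypothetical impact values: higher means needed earlier by the backward pass.
# s0_x and s0_B inherit the impact of their consumer chain ending in s2_A.
impact = {"s2_x": 30, "s2_A": 30, "s2_y": 30,
          "s1_x": 20, "s1_y": 20,
          "s0_x": 30, "s0_B": 30, "s0_y": 10}

by_index = priority_topo_sort(deps, key=lambda n: index[n])
by_impact = priority_topo_sort(deps, key=lambda n: (-impact[n], index[n]))
```

In this toy graph the index-only order emits `s2_A` sixth, after all of subgraph 1, while the impact-aware order emits it fourth, directly after its subgraph-0 dependencies.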

Defining Critical Path Impact
Critical path impact is a metric representing a node's influence on the critical path. It is determined
during MemoryOptimizer's operation as follows:
1) Sort graphs without recompute optimization to establish a baseline topological order.
2) Apply recompute optimization.
3) Identify recompute boundary nodes (recompute nodes not consumed by others).
4) For each boundary node, calculate the minimum topological order of all output nodes.
The minimum value indicates the earliest need for the recompute node's execution.
We assign std::numeric_limits<int64_t>::max() - min_topological_order as the critical path impact.
5) For other recompute nodes, assign the maximum critical path impact of their output nodes.
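
Steps 3 to 5 can be sketched in Python as follows. This is an illustration under stated assumptions, not the code in this PR: the function name `critical_path_impact`, the `consumers` map, and the example node names are hypothetical. It assumes a boundary node is a recompute node with no recompute consumers, that every boundary node has at least one consumer in the baseline order, and that only recompute consumers contribute in step 5. `INT64_MAX` stands in for `std::numeric_limits<int64_t>::max()`.

```python
INT64_MAX = 2**63 - 1  # mirrors std::numeric_limits<int64_t>::max()

def critical_path_impact(recompute_nodes, consumers, baseline_order):
    """Return {recompute node: impact}; a higher impact means needed earlier.

    consumers:      node -> list of nodes that consume its outputs
    baseline_order: non-recompute node -> position in the topological
                    order computed before recompute was applied (step 1)
    """
    recompute = set(recompute_nodes)
    impact = {}

    def visit(node):
        if node in impact:
            return impact[node]
        recompute_outputs = [c for c in consumers[node] if c in recompute]
        if not recompute_outputs:
            # Steps 3 and 4: boundary node. The earliest consumer in the
            # baseline order decides how soon this node is needed.
            earliest = min(baseline_order[c] for c in consumers[node])
            impact[node] = INT64_MAX - earliest
        else:
            # Step 5: inherit the largest impact among recompute consumers.
            impact[node] = max(visit(c) for c in recompute_outputs)
        return impact[node]

    for node in recompute_nodes:
        visit(node)
    return impact

# Hypothetical example: r0 -> r1 feeds two backward nodes, s0 feeds one.
consumers = {"r0": ["r1"], "r1": ["bwd_a", "bwd_b"], "s0": ["bwd_c"]}
baseline_order = {"bwd_a": 10, "bwd_b": 12, "bwd_c": 3}
result = critical_path_impact(["r0", "r1", "s0"], consumers, baseline_order)
```

Here `s0` receives the highest impact because its consumer `bwd_c` appears earliest in the baseline order, and `r0` inherits the impact of the boundary node `r1`.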

DEFECTs

This change to the priority-based topological sort gives up the existing priority ordering for recompute nodes. It also requires an initial topological sort without memory optimization applied, and that first sort can strongly affect how the recompute graph is scheduled. Here is one example:

Some nodes that do not contribute to YieldOp are treated as non-forward ops, so they are scheduled after YieldOp. Such a node may still be scheduled early, in which case its dependent recompute subgraph is scheduled early too. This is not always optimal.

Motivation and Context

Mistral models cannot run user recipes when recompute is enabled. The root cause is that the recompute nodes execute in the wrong order, which makes the memory savings very limited.

@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label Mar 28, 2024
Base automatically changed from pengwa/priority_tuning to main March 29, 2024 09:44
@pengwa pengwa marked this pull request as ready for review March 29, 2024 10:15
@pengwa
Contributor Author

pengwa commented Apr 8, 2024

Will be split into separate PRs.

@pengwa pengwa closed this Apr 8, 2024
@pengwa
Contributor Author

pengwa commented Apr 8, 2024

The first one: #20234

@pengwa pengwa deleted the pengwa/recompute_in_critical_path branch April 22, 2024 09:51